package server

import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"log"
	"net"
	"strconv"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/insecure"
)

// GrpcClient holds the settings used to reach a gRPC server and caches the
// *grpc.ClientConn created by GetConn so that later calls reuse it.
//
// TLS is used only when TlsPublicPath, TlsPrivatePath and CaCertPath are all
// non-empty; otherwise GetConn connects with insecure (plaintext) credentials.
type GrpcClient struct {
	// Host is the gRPC server's host name or IP address.
	Host           string
	// Port is the gRPC server's TCP port.
	Port           int
	// TlsPublicPath is the client certificate file passed to tls.LoadX509KeyPair.
	TlsPublicPath  string
	// TlsPrivatePath is meant to be the client private-key file matching
	// TlsPublicPath. NOTE(review): presumably PEM-encoded; confirm.
	TlsPrivatePath string
	// CaCertPath is a PEM file of CA certificates used as RootCAs to verify
	// the server.
	CaCertPath     string
	// conn caches the connection returned by GetConn; nil until the first
	// successful dial. It is read and written without any locking.
	conn           *grpc.ClientConn
}

// GetConn returns the client's gRPC connection, dialing it on first use and
// caching it for subsequent calls.
//
// When TlsPublicPath, TlsPrivatePath and CaCertPath are all set, the
// connection uses mutual TLS. The client presents the certificate/key pair
// loaded from TlsPublicPath/TlsPrivatePath and verifies the server against
// the CA certificates in CaCertPath. Otherwise the connection is plaintext
// (insecure), and a warning is logged if only some of the TLS paths are set.
//
// In both modes the dial target is Host:Port. Failures are returned as
// wrapped errors rather than terminating the process. The cached connection
// is accessed without locking, so concurrent first calls may each dial.
func (grpcClient *GrpcClient) GetConn() (*grpc.ClientConn, error) {
	if grpcClient.conn != nil {
		return grpcClient.conn, nil
	}

	creds, err := grpcClient.transportCredentials()
	if err != nil {
		return nil, err
	}

	target := grpcClient.dialTarget()
	conn, err := grpc.Dial(target, grpc.WithTransportCredentials(creds))
	if err != nil {
		return nil, fmt.Errorf("dialing gRPC server %s: %w", target, err)
	}
	grpcClient.conn = conn
	return conn, nil
}

// dialTarget returns the "host:port" dial address built from Host and Port.
// net.JoinHostPort brackets IPv6 literals, which plain concatenation does not.
func (grpcClient *GrpcClient) dialTarget() string {
	return net.JoinHostPort(grpcClient.Host, strconv.Itoa(grpcClient.Port))
}

// transportCredentials selects mutual-TLS credentials when all three TLS paths
// are configured, and insecure (plaintext) credentials otherwise.
func (grpcClient *GrpcClient) transportCredentials() (credentials.TransportCredentials, error) {
	certSet := grpcClient.TlsPublicPath != ""
	keySet := grpcClient.TlsPrivatePath != ""
	caSet := grpcClient.CaCertPath != ""

	if certSet && keySet && caSet {
		return grpcClient.loadMutualTLSCredentials()
	}
	if certSet || keySet || caSet {
		// The plaintext fallback is kept for compatibility, but a half-finished
		// TLS setup should not downgrade to an unencrypted channel unnoticed.
		log.Printf("grpc client: TlsPublicPath, TlsPrivatePath and CaCertPath must all be set to enable TLS; connecting to %s without TLS", grpcClient.dialTarget())
	}
	return insecure.NewCredentials(), nil
}

// loadMutualTLSCredentials builds client TLS credentials from the configured
// certificate, private-key and CA files.
func (grpcClient *GrpcClient) loadMutualTLSCredentials() (credentials.TransportCredentials, error) {
	// The second argument of LoadX509KeyPair is the key file, so it must be
	// TlsPrivatePath rather than the certificate path.
	cert, err := tls.LoadX509KeyPair(grpcClient.TlsPublicPath, grpcClient.TlsPrivatePath)
	if err != nil {
		return nil, fmt.Errorf("loading client certificate %q and key %q: %w",
			grpcClient.TlsPublicPath, grpcClient.TlsPrivatePath, err)
	}

	caCert, err := ioutil.ReadFile(grpcClient.CaCertPath)
	if err != nil {
		return nil, fmt.Errorf("reading CA certificate %q: %w", grpcClient.CaCertPath, err)
	}
	caCertPool := x509.NewCertPool()
	if !caCertPool.AppendCertsFromPEM(caCert) {
		// An empty pool would make every handshake fail with an opaque
		// "unknown authority" error; report the real cause here instead.
		return nil, fmt.Errorf("no PEM certificates found in CA file %q", grpcClient.CaCertPath)
	}

	return credentials.NewTLS(&tls.Config{
		Certificates: []tls.Certificate{cert},
		RootCAs:      caCertPool,
	}), nil
}

// GetGrpcClient fetches the (possibly cached) connection via GetConn and hands
// it to initClientFun, returning whatever client value and error that
// constructor produces. If GetConn fails, initClientFun is never invoked and
// the GetConn error is returned alongside a nil client.
func (grpcClient *GrpcClient) GetGrpcClient(initClientFun func(conn *grpc.ClientConn) (interface{}, error)) (interface{}, error) {
	var (
		clientConn *grpc.ClientConn
		dialErr    error
	)
	if clientConn, dialErr = grpcClient.GetConn(); dialErr != nil {
		return nil, dialErr
	}
	// Building the typed service stub is delegated entirely to the caller.
	return initClientFun(clientConn)
}
